{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Networks\n",
    "*Complete and hand in this completed worksheet (including its outputs and any supporting code outside of the worksheet) with your assignment submission. Please check the pdf file for more details.*\n",
    "\n",
    "In this exercise you will:\n",
    "    \n",
    "- implement the **forward** and **backward** operations for different layers in neural networks\n",
    "- implement a simple neural network for classification\n",
    "\n",
    "Please note that **YOU CANNOT USE ANY MACHINE LEARNING PACKAGE SUCH AS SKLEARN** for any homework, unless you are asked to."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# some basic imports\n",
    "import scipy.io as sio\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG so the random weight initialisation below is repeatable\n",
    "# after a kernel restart (the original cell drew unseeded values).\n",
    "np.random.seed(0)\n",
    "\n",
    "# Sizes shared by the reshape and the weight shapes: inputs, hidden units, outputs.\n",
    "num_pixels, num_hidden, num_outputs = 400, 25, 10\n",
    "\n",
    "# Load the raw arrays stored under the keys 'X' and 'y' of the .mat file.\n",
    "digit_data = sio.loadmat('digit_data.mat')\n",
    "X = digit_data['X']\n",
    "y = digit_data['y']\n",
    "_, num_cases = X.shape\n",
    "train_num_cases = num_cases * 4 // 5\n",
    "X = X.reshape((num_pixels, num_cases))\n",
    "X = X.transpose()\n",
    "# X has the shape of (number of samples, number of pixels)\n",
    "\n",
    "# The first 4/5 of the samples form the training split, the rest the test split.\n",
    "# Labels are sliced along axis 1, i.e. one column of y per sample.\n",
    "train_data = X[:train_num_cases, :]\n",
    "train_label = y[:, :train_num_cases]\n",
    "test_data = X[train_num_cases:, :]\n",
    "test_label = y[:, train_num_cases:]\n",
    "\n",
    "# Weights are Gaussian draws scaled down by the layer's input size;\n",
    "# biases are uniform draws from [0, 1).\n",
    "weights = {}\n",
    "weights['fully1_weight'] = np.random.randn(num_pixels, num_hidden) / num_pixels\n",
    "weights['fully1_bias'] = np.random.rand(num_hidden, 1)\n",
    "weights['fully2_weight'] = np.random.randn(num_hidden, num_outputs) / num_hidden\n",
    "weights['fully2_bias'] = np.random.rand(num_outputs, 1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  1. 1 loss:2.41, accuracy:0.12\n",
      "  1. 2 loss:2.38, accuracy:0.09\n",
      "  1. 3 loss:2.32, accuracy:0.09\n",
      "  1. 4 loss:2.29, accuracy:0.15\n",
      "  1. 5 loss:2.28, accuracy:0.11\n",
      "  1. 6 loss:2.27, accuracy:0.1\n",
      "  1. 7 loss:2.26, accuracy:0.08\n",
      "  1. 8 loss:2.24, accuracy:0.1\n",
      "  1. 9 loss:2.26, accuracy:0.19\n",
      "  1.10 loss:2.19, accuracy:0.31\n",
      "  1.11 loss:2.19, accuracy:0.2\n",
      "  1.12 loss:2.12, accuracy:0.22\n",
      "  1.13 loss:2.09, accuracy:0.28\n",
      "  1.14 loss:2.07, accuracy:0.29\n",
      "  1.15 loss:1.96, accuracy:0.38\n",
      "  1.16 loss:1.9, accuracy:0.49\n",
      "  1.17 loss:1.84, accuracy:0.47\n",
      "  1.18 loss:1.79, accuracy:0.5\n",
      "  1.19 loss:1.75, accuracy:0.42\n",
      "  1.20 loss:1.52, accuracy:0.6\n",
      "  1.21 loss:1.45, accuracy:0.62\n",
      "  1.22 loss:1.41, accuracy:0.59\n",
      "  1.23 loss:1.27, accuracy:0.6\n",
      "  1.24 loss:1.12, accuracy:0.71\n",
      "  1.25 loss:1.1, accuracy:0.75\n",
      "  1.26 loss:1.01, accuracy:0.75\n",
      "  1.27 loss:1.07, accuracy:0.67\n",
      "  1.28 loss:0.964, accuracy:0.69\n",
      "  1.29 loss:0.801, accuracy:0.81\n",
      "  1.30 loss:0.855, accuracy:0.74\n",
      "  1.31 loss:0.825, accuracy:0.69\n",
      "  1.32 loss:0.818, accuracy:0.68\n",
      "  1.33 loss:0.674, accuracy:0.72\n",
      "  1.34 loss:0.775, accuracy:0.72\n",
      "  1.35 loss:0.877, accuracy:0.76\n",
      "  1.36 loss:0.69, accuracy:0.81\n",
      "  1.37 loss:0.615, accuracy:0.8\n",
      "  1.38 loss:0.604, accuracy:0.8\n",
      "  1.39 loss:0.665, accuracy:0.82\n",
      "  1.40 loss:0.678, accuracy:0.79\n",
      "  2. 1 loss:0.657, accuracy:0.73\n",
      "  2. 2 loss:0.685, accuracy:0.79\n",
      "  2. 3 loss:0.513, accuracy:0.82\n",
      "  2. 4 loss:0.686, accuracy:0.79\n",
      "  2. 5 loss:0.712, accuracy:0.73\n",
      "  2. 6 loss:0.42, accuracy:0.84\n",
      "  2. 7 loss:0.638, accuracy:0.81\n",
      "  2. 8 loss:0.412, accuracy:0.86\n",
      "  2. 9 loss:0.471, accuracy:0.84\n",
      "  2.10 loss:0.656, accuracy:0.79\n",
      "  2.11 loss:0.656, accuracy:0.8\n",
      "  2.12 loss:0.856, accuracy:0.79\n",
      "  2.13 loss:0.623, accuracy:0.74\n",
      "  2.14 loss:0.686, accuracy:0.75\n",
      "  2.15 loss:0.302, accuracy:0.94\n",
      "  2.16 loss:0.451, accuracy:0.87\n",
      "  2.17 loss:0.671, accuracy:0.77\n",
      "  2.18 loss:0.638, accuracy:0.84\n",
      "  2.19 loss:0.584, accuracy:0.82\n",
      "  2.20 loss:0.441, accuracy:0.86\n",
      "  2.21 loss:0.42, accuracy:0.89\n",
      "  2.22 loss:0.647, accuracy:0.85\n",
      "  2.23 loss:0.494, accuracy:0.84\n",
      "  2.24 loss:0.446, accuracy:0.85\n",
      "  2.25 loss:0.355, accuracy:0.91\n",
      "  2.26 loss:0.395, accuracy:0.91\n",
      "  2.27 loss:0.649, accuracy:0.78\n",
      "  2.28 loss:0.512, accuracy:0.87\n",
      "  2.29 loss:0.396, accuracy:0.87\n",
      "  2.30 loss:0.384, accuracy:0.87\n",
      "  2.31 loss:0.515, accuracy:0.88\n",
      "  2.32 loss:0.462, accuracy:0.82\n",
      "  2.33 loss:0.482, accuracy:0.85\n",
      "  2.34 loss:0.475, accuracy:0.84\n",
      "  2.35 loss:0.526, accuracy:0.86\n",
      "  2.36 loss:0.387, accuracy:0.89\n",
      "  2.37 loss:0.333, accuracy:0.92\n",
      "  2.38 loss:0.374, accuracy:0.91\n",
      "  2.39 loss:0.368, accuracy:0.91\n",
      "  2.40 loss:0.44, accuracy:0.9\n",
      "  3. 1 loss:0.32, accuracy:0.9\n",
      "  3. 2 loss:0.477, accuracy:0.89\n",
      "  3. 3 loss:0.275, accuracy:0.93\n",
      "  3. 4 loss:0.434, accuracy:0.88\n",
      "  3. 5 loss:0.325, accuracy:0.9\n",
      "  3. 6 loss:0.201, accuracy:0.97\n",
      "  3. 7 loss:0.418, accuracy:0.88\n",
      "  3. 8 loss:0.277, accuracy:0.92\n",
      "  3. 9 loss:0.215, accuracy:0.93\n",
      "  3.10 loss:0.403, accuracy:0.92\n",
      "  3.11 loss:0.301, accuracy:0.91\n",
      "  3.12 loss:0.522, accuracy:0.86\n",
      "  3.13 loss:0.48, accuracy:0.82\n",
      "  3.14 loss:0.518, accuracy:0.85\n",
      "  3.15 loss:0.222, accuracy:0.92\n",
      "  3.16 loss:0.272, accuracy:0.94\n",
      "  3.17 loss:0.353, accuracy:0.87\n",
      "  3.18 loss:0.515, accuracy:0.84\n",
      "  3.19 loss:0.487, accuracy:0.84\n",
      "  3.20 loss:0.272, accuracy:0.92\n",
      "  3.21 loss:0.293, accuracy:0.91\n",
      "  3.22 loss:0.347, accuracy:0.9\n",
      "  3.23 loss:0.301, accuracy:0.91\n",
      "  3.24 loss:0.294, accuracy:0.93\n",
      "  3.25 loss:0.236, accuracy:0.96\n",
      "  3.26 loss:0.286, accuracy:0.9\n",
      "  3.27 loss:0.44, accuracy:0.84\n",
      "  3.28 loss:0.316, accuracy:0.93\n",
      "  3.29 loss:0.289, accuracy:0.91\n",
      "  3.30 loss:0.241, accuracy:0.94\n",
      "  3.31 loss:0.367, accuracy:0.89\n",
      "  3.32 loss:0.228, accuracy:0.93\n",
      "  3.33 loss:0.303, accuracy:0.87\n",
      "  3.34 loss:0.36, accuracy:0.9\n",
      "  3.35 loss:0.41, accuracy:0.88\n",
      "  3.36 loss:0.295, accuracy:0.9\n",
      "  3.37 loss:0.247, accuracy:0.92\n",
      "  3.38 loss:0.294, accuracy:0.94\n",
      "  3.39 loss:0.278, accuracy:0.91\n",
      "  3.40 loss:0.435, accuracy:0.91\n",
      "  4. 1 loss:0.232, accuracy:0.94\n",
      "  4. 2 loss:0.316, accuracy:0.91\n",
      "  4. 3 loss:0.215, accuracy:0.95\n",
      "  4. 4 loss:0.34, accuracy:0.91\n",
      "  4. 5 loss:0.274, accuracy:0.92\n",
      "  4. 6 loss:0.156, accuracy:0.99\n",
      "  4. 7 loss:0.404, accuracy:0.88\n",
      "  4. 8 loss:0.226, accuracy:0.96\n",
      "  4. 9 loss:0.196, accuracy:0.96\n",
      "  4.10 loss:0.323, accuracy:0.89\n",
      "  4.11 loss:0.257, accuracy:0.89\n",
      "  4.12 loss:0.392, accuracy:0.91\n",
      "  4.13 loss:0.361, accuracy:0.87\n",
      "  4.14 loss:0.444, accuracy:0.89\n",
      "  4.15 loss:0.181, accuracy:0.94\n",
      "  4.16 loss:0.268, accuracy:0.92\n",
      "  4.17 loss:0.254, accuracy:0.91\n",
      "  4.18 loss:0.358, accuracy:0.91\n",
      "  4.19 loss:0.402, accuracy:0.87\n",
      "  4.20 loss:0.225, accuracy:0.92\n",
      "  4.21 loss:0.254, accuracy:0.92\n",
      "  4.22 loss:0.268, accuracy:0.9\n",
      "  4.23 loss:0.257, accuracy:0.93\n",
      "  4.24 loss:0.239, accuracy:0.92\n",
      "  4.25 loss:0.173, accuracy:0.96\n",
      "  4.26 loss:0.213, accuracy:0.95\n",
      "  4.27 loss:0.364, accuracy:0.88\n",
      "  4.28 loss:0.256, accuracy:0.93\n",
      "  4.29 loss:0.228, accuracy:0.94\n",
      "  4.30 loss:0.188, accuracy:0.96\n",
      "  4.31 loss:0.276, accuracy:0.92\n",
      "  4.32 loss:0.227, accuracy:0.92\n",
      "  4.33 loss:0.265, accuracy:0.89\n",
      "  4.34 loss:0.279, accuracy:0.93\n",
      "  4.35 loss:0.365, accuracy:0.91\n",
      "  4.36 loss:0.289, accuracy:0.92\n",
      "  4.37 loss:0.204, accuracy:0.92\n",
      "  4.38 loss:0.263, accuracy:0.94\n",
      "  4.39 loss:0.231, accuracy:0.94\n",
      "  4.40 loss:0.353, accuracy:0.91\n",
      "  5. 1 loss:0.174, accuracy:0.94\n",
      "  5. 2 loss:0.252, accuracy:0.93\n",
      "  5. 3 loss:0.192, accuracy:0.96\n",
      "  5. 4 loss:0.283, accuracy:0.93\n",
      "  5. 5 loss:0.268, accuracy:0.93\n",
      "  5. 6 loss:0.135, accuracy:0.98\n",
      "  5. 7 loss:0.324, accuracy:0.9\n",
      "  5. 8 loss:0.163, accuracy:0.97\n",
      "  5. 9 loss:0.166, accuracy:0.96\n",
      "  5.10 loss:0.283, accuracy:0.91\n",
      "  5.11 loss:0.259, accuracy:0.89\n",
      "  5.12 loss:0.345, accuracy:0.92\n",
      "  5.13 loss:0.292, accuracy:0.89\n",
      "  5.14 loss:0.41, accuracy:0.91\n",
      "  5.15 loss:0.142, accuracy:0.96\n",
      "  5.16 loss:0.248, accuracy:0.92\n",
      "  5.17 loss:0.207, accuracy:0.94\n",
      "  5.18 loss:0.306, accuracy:0.93\n",
      "  5.19 loss:0.389, accuracy:0.89\n",
      "  5.20 loss:0.221, accuracy:0.94\n",
      "  5.21 loss:0.205, accuracy:0.94\n",
      "  5.22 loss:0.231, accuracy:0.92\n",
      "  5.23 loss:0.237, accuracy:0.94\n",
      "  5.24 loss:0.211, accuracy:0.92\n",
      "  5.25 loss:0.16, accuracy:0.96\n",
      "  5.26 loss:0.177, accuracy:0.94\n",
      "  5.27 loss:0.299, accuracy:0.92\n",
      "  5.28 loss:0.21, accuracy:0.96\n",
      "  5.29 loss:0.194, accuracy:0.95\n",
      "  5.30 loss:0.168, accuracy:0.97\n",
      "  5.31 loss:0.237, accuracy:0.94\n",
      "  5.32 loss:0.194, accuracy:0.94\n",
      "  5.33 loss:0.243, accuracy:0.91\n",
      "  5.34 loss:0.269, accuracy:0.92\n",
      "  5.35 loss:0.343, accuracy:0.9\n",
      "  5.36 loss:0.245, accuracy:0.94\n",
      "  5.37 loss:0.18, accuracy:0.91\n",
      "  5.38 loss:0.231, accuracy:0.94\n",
      "  5.39 loss:0.22, accuracy:0.93\n",
      "  5.40 loss:0.309, accuracy:0.92\n",
      "  6. 1 loss:0.152, accuracy:0.96\n",
      "  6. 2 loss:0.224, accuracy:0.93\n",
      "  6. 3 loss:0.167, accuracy:0.96\n",
      "  6. 4 loss:0.225, accuracy:0.95\n",
      "  6. 5 loss:0.235, accuracy:0.94\n",
      "  6. 6 loss:0.142, accuracy:0.97\n",
      "  6. 7 loss:0.306, accuracy:0.91\n",
      "  6. 8 loss:0.14, accuracy:0.98\n",
      "  6. 9 loss:0.131, accuracy:0.97\n",
      "  6.10 loss:0.218, accuracy:0.93\n",
      "  6.11 loss:0.192, accuracy:0.92\n",
      "  6.12 loss:0.295, accuracy:0.94\n",
      "  6.13 loss:0.254, accuracy:0.89\n",
      "  6.14 loss:0.386, accuracy:0.91\n",
      "  6.15 loss:0.107, accuracy:0.96\n",
      "  6.16 loss:0.194, accuracy:0.95\n",
      "  6.17 loss:0.174, accuracy:0.94\n",
      "  6.18 loss:0.28, accuracy:0.93\n",
      "  6.19 loss:0.363, accuracy:0.89\n",
      "  6.20 loss:0.189, accuracy:0.95\n",
      "  6.21 loss:0.171, accuracy:0.95\n",
      "  6.22 loss:0.187, accuracy:0.93\n",
      "  6.23 loss:0.205, accuracy:0.94\n",
      "  6.24 loss:0.19, accuracy:0.92\n",
      "  6.25 loss:0.134, accuracy:0.97\n",
      "  6.26 loss:0.161, accuracy:0.95\n",
      "  6.27 loss:0.28, accuracy:0.93\n",
      "  6.28 loss:0.192, accuracy:0.98\n",
      "  6.29 loss:0.183, accuracy:0.95\n",
      "  6.30 loss:0.146, accuracy:0.96\n",
      "  6.31 loss:0.214, accuracy:0.95\n",
      "  6.32 loss:0.155, accuracy:0.96\n",
      "  6.33 loss:0.205, accuracy:0.94\n",
      "  6.34 loss:0.255, accuracy:0.92\n",
      "  6.35 loss:0.319, accuracy:0.9\n",
      "  6.36 loss:0.224, accuracy:0.95\n",
      "  6.37 loss:0.155, accuracy:0.95\n",
      "  6.38 loss:0.204, accuracy:0.94\n",
      "  6.39 loss:0.219, accuracy:0.91\n",
      "  6.40 loss:0.289, accuracy:0.91\n",
      "  7. 1 loss:0.124, accuracy:0.96\n",
      "  7. 2 loss:0.191, accuracy:0.94\n",
      "  7. 3 loss:0.149, accuracy:0.96\n",
      "  7. 4 loss:0.199, accuracy:0.96\n",
      "  7. 5 loss:0.201, accuracy:0.95\n",
      "  7. 6 loss:0.142, accuracy:0.97\n",
      "  7. 7 loss:0.287, accuracy:0.9\n",
      "  7. 8 loss:0.143, accuracy:0.96\n",
      "  7. 9 loss:0.121, accuracy:0.97\n",
      "  7.10 loss:0.185, accuracy:0.95\n",
      "  7.11 loss:0.152, accuracy:0.95\n",
      "  7.12 loss:0.237, accuracy:0.95\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  7.13 loss:0.202, accuracy:0.91\n",
      "  7.14 loss:0.355, accuracy:0.92\n",
      "  7.15 loss:0.0913, accuracy:0.96\n",
      "  7.16 loss:0.171, accuracy:0.96\n",
      "  7.17 loss:0.157, accuracy:0.95\n",
      "  7.18 loss:0.254, accuracy:0.94\n",
      "  7.19 loss:0.316, accuracy:0.88\n",
      "  7.20 loss:0.165, accuracy:0.95\n",
      "  7.21 loss:0.157, accuracy:0.95\n",
      "  7.22 loss:0.15, accuracy:0.94\n",
      "  7.23 loss:0.183, accuracy:0.98\n",
      "  7.24 loss:0.167, accuracy:0.94\n",
      "  7.25 loss:0.117, accuracy:0.98\n",
      "  7.26 loss:0.155, accuracy:0.97\n",
      "  7.27 loss:0.259, accuracy:0.91\n",
      "  7.28 loss:0.176, accuracy:0.98\n",
      "  7.29 loss:0.178, accuracy:0.94\n",
      "  7.30 loss:0.141, accuracy:0.97\n",
      "  7.31 loss:0.201, accuracy:0.95\n",
      "  7.32 loss:0.133, accuracy:0.98\n",
      "  7.33 loss:0.166, accuracy:0.95\n",
      "  7.34 loss:0.22, accuracy:0.94\n",
      "  7.35 loss:0.283, accuracy:0.91\n",
      "  7.36 loss:0.204, accuracy:0.96\n",
      "  7.37 loss:0.146, accuracy:0.95\n",
      "  7.38 loss:0.182, accuracy:0.95\n",
      "  7.39 loss:0.213, accuracy:0.92\n",
      "  7.40 loss:0.274, accuracy:0.91\n",
      "  8. 1 loss:0.116, accuracy:0.96\n",
      "  8. 2 loss:0.154, accuracy:0.97\n",
      "  8. 3 loss:0.136, accuracy:0.96\n",
      "  8. 4 loss:0.178, accuracy:0.96\n",
      "  8. 5 loss:0.179, accuracy:0.95\n",
      "  8. 6 loss:0.13, accuracy:0.97\n",
      "  8. 7 loss:0.248, accuracy:0.92\n",
      "  8. 8 loss:0.126, accuracy:0.96\n",
      "  8. 9 loss:0.105, accuracy:0.97\n",
      "  8.10 loss:0.155, accuracy:0.95\n",
      "  8.11 loss:0.137, accuracy:0.94\n",
      "  8.12 loss:0.204, accuracy:0.96\n",
      "  8.13 loss:0.164, accuracy:0.93\n",
      "  8.14 loss:0.291, accuracy:0.92\n",
      "  8.15 loss:0.0701, accuracy:0.98\n",
      "  8.16 loss:0.156, accuracy:0.96\n",
      "  8.17 loss:0.164, accuracy:0.95\n",
      "  8.18 loss:0.251, accuracy:0.95\n",
      "  8.19 loss:0.289, accuracy:0.91\n",
      "  8.20 loss:0.156, accuracy:0.95\n",
      "  8.21 loss:0.135, accuracy:0.95\n",
      "  8.22 loss:0.136, accuracy:0.98\n",
      "  8.23 loss:0.167, accuracy:0.97\n",
      "  8.24 loss:0.156, accuracy:0.95\n",
      "  8.25 loss:0.115, accuracy:0.97\n",
      "  8.26 loss:0.141, accuracy:0.97\n",
      "  8.27 loss:0.228, accuracy:0.92\n",
      "  8.28 loss:0.159, accuracy:0.98\n",
      "  8.29 loss:0.17, accuracy:0.97\n",
      "  8.30 loss:0.137, accuracy:0.97\n",
      "  8.31 loss:0.188, accuracy:0.95\n",
      "  8.32 loss:0.119, accuracy:0.98\n",
      "  8.33 loss:0.142, accuracy:0.96\n",
      "  8.34 loss:0.2, accuracy:0.95\n",
      "  8.35 loss:0.274, accuracy:0.93\n",
      "  8.36 loss:0.189, accuracy:0.95\n",
      "  8.37 loss:0.14, accuracy:0.96\n",
      "  8.38 loss:0.151, accuracy:0.96\n",
      "  8.39 loss:0.21, accuracy:0.92\n",
      "  8.40 loss:0.248, accuracy:0.92\n",
      "  9. 1 loss:0.107, accuracy:0.96\n",
      "  9. 2 loss:0.146, accuracy:0.97\n",
      "  9. 3 loss:0.141, accuracy:0.96\n",
      "  9. 4 loss:0.167, accuracy:0.98\n",
      "  9. 5 loss:0.166, accuracy:0.96\n",
      "  9. 6 loss:0.101, accuracy:0.97\n",
      "  9. 7 loss:0.213, accuracy:0.92\n",
      "  9. 8 loss:0.116, accuracy:0.96\n",
      "  9. 9 loss:0.0986, accuracy:0.97\n",
      "  9.10 loss:0.152, accuracy:0.95\n",
      "  9.11 loss:0.12, accuracy:0.95\n",
      "  9.12 loss:0.188, accuracy:0.96\n",
      "  9.13 loss:0.158, accuracy:0.94\n",
      "  9.14 loss:0.251, accuracy:0.93\n",
      "  9.15 loss:0.0618, accuracy:0.98\n",
      "  9.16 loss:0.147, accuracy:0.96\n",
      "  9.17 loss:0.172, accuracy:0.95\n",
      "  9.18 loss:0.247, accuracy:0.95\n",
      "  9.19 loss:0.268, accuracy:0.91\n",
      "  9.20 loss:0.155, accuracy:0.95\n",
      "  9.21 loss:0.118, accuracy:0.96\n",
      "  9.22 loss:0.124, accuracy:0.97\n",
      "  9.23 loss:0.153, accuracy:0.97\n",
      "  9.24 loss:0.149, accuracy:0.95\n",
      "  9.25 loss:0.116, accuracy:0.95\n",
      "  9.26 loss:0.124, accuracy:0.97\n",
      "  9.27 loss:0.194, accuracy:0.94\n",
      "  9.28 loss:0.145, accuracy:0.97\n",
      "  9.29 loss:0.162, accuracy:0.95\n",
      "  9.30 loss:0.127, accuracy:0.98\n",
      "  9.31 loss:0.175, accuracy:0.96\n",
      "  9.32 loss:0.107, accuracy:0.96\n",
      "  9.33 loss:0.135, accuracy:0.95\n",
      "  9.34 loss:0.188, accuracy:0.95\n",
      "  9.35 loss:0.27, accuracy:0.92\n",
      "  9.36 loss:0.168, accuracy:0.95\n",
      "  9.37 loss:0.126, accuracy:0.96\n",
      "  9.38 loss:0.128, accuracy:0.97\n",
      "  9.39 loss:0.183, accuracy:0.94\n",
      "  9.40 loss:0.211, accuracy:0.92\n",
      " 10. 1 loss:0.0863, accuracy:0.97\n",
      " 10. 2 loss:0.139, accuracy:0.97\n",
      " 10. 3 loss:0.119, accuracy:0.95\n",
      " 10. 4 loss:0.161, accuracy:0.98\n",
      " 10. 5 loss:0.172, accuracy:0.95\n",
      " 10. 6 loss:0.0811, accuracy:0.98\n",
      " 10. 7 loss:0.179, accuracy:0.95\n",
      " 10. 8 loss:0.0968, accuracy:0.98\n",
      " 10. 9 loss:0.0876, accuracy:0.99\n",
      " 10.10 loss:0.154, accuracy:0.95\n",
      " 10.11 loss:0.0987, accuracy:0.95\n",
      " 10.12 loss:0.173, accuracy:0.95\n",
      " 10.13 loss:0.163, accuracy:0.95\n",
      " 10.14 loss:0.224, accuracy:0.93\n",
      " 10.15 loss:0.0583, accuracy:0.98\n",
      " 10.16 loss:0.13, accuracy:0.95\n",
      " 10.17 loss:0.166, accuracy:0.93\n",
      " 10.18 loss:0.238, accuracy:0.94\n",
      " 10.19 loss:0.244, accuracy:0.92\n",
      " 10.20 loss:0.128, accuracy:0.95\n",
      " 10.21 loss:0.0997, accuracy:0.97\n",
      " 10.22 loss:0.106, accuracy:0.97\n",
      " 10.23 loss:0.144, accuracy:0.97\n",
      " 10.24 loss:0.15, accuracy:0.94\n",
      " 10.25 loss:0.121, accuracy:0.95\n",
      " 10.26 loss:0.113, accuracy:0.98\n",
      " 10.27 loss:0.19, accuracy:0.93\n",
      " 10.28 loss:0.131, accuracy:0.97\n",
      " 10.29 loss:0.142, accuracy:0.96\n",
      " 10.30 loss:0.102, accuracy:0.98\n",
      " 10.31 loss:0.158, accuracy:0.95\n",
      " 10.32 loss:0.102, accuracy:0.97\n",
      " 10.33 loss:0.124, accuracy:0.95\n",
      " 10.34 loss:0.18, accuracy:0.96\n",
      " 10.35 loss:0.262, accuracy:0.92\n",
      " 10.36 loss:0.141, accuracy:0.97\n",
      " 10.37 loss:0.111, accuracy:0.97\n",
      " 10.38 loss:0.119, accuracy:0.97\n",
      " 10.39 loss:0.15, accuracy:0.96\n",
      " 10.40 loss:0.173, accuracy:0.94\n"
     ]
    }
   ],
   "source": [
    "# Project-local helpers used by the loop below.\n",
    "from get_new_weight_inc import get_new_weight_inc\n",
    "from feedforward_backprop import feedforward_backprop\n",
    "\n",
    "# training setting\n",
    "# batch_size and max_epoch drive the loop; momW, wc and learning_rate are passed\n",
    "# unchanged to get_new_weight_inc (presumably momentum and weight decay - confirm there).\n",
    "batch_size = 100\n",
    "max_epoch = 10\n",
    "momW = 0.9\n",
    "wc = 0.0005\n",
    "learning_rate = 0.1\n",
    "\n",
    "param_names = ('fully1_weight', 'fully1_bias', 'fully2_weight', 'fully2_bias')\n",
    "# One zero-filled increment buffer per parameter, with that parameter's shape.\n",
    "weight_inc = {name: np.zeros(weights[name].shape) for name in param_names}\n",
    "progress_fmt = '{:3}.{:2} loss:{:.3}, accuracy:{}'\n",
    "\n",
    "# Training iterations\n",
    "for epoch in range(max_epoch):\n",
    "    # Walk the training split in consecutive mini-batches; the last one may be shorter.\n",
    "    for i, start in enumerate(range(0, train_num_cases, batch_size)):\n",
    "        stop = min(start + batch_size, train_num_cases)\n",
    "        data = train_data[start:stop, :]\n",
    "        label = train_label[:, start:stop]\n",
    "        # The feedforward and backpropagation processes\n",
    "        loss, accuracy, gradients = feedforward_backprop(data, label, weights)\n",
    "        print(progress_fmt.format(epoch + 1, i + 1, loss, accuracy))\n",
    "        # Updating weights\n",
    "        for name in param_names:\n",
    "            grad = gradients[name + '_grad']\n",
    "            weight_inc[name] = get_new_weight_inc(weight_inc[name], weights[name], momW, wc, learning_rate, grad)\n",
    "            # In-place add keeps the arrays stored in the weights dict up to date.\n",
    "            weights[name] += weight_inc[name]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss:0.288, accuracy:0.918\n"
     ]
    }
   ],
   "source": [
    "# Run feedforward_backprop on the held-out test split; its third return value is discarded.\n",
    "loss, accuracy, _ = feedforward_backprop(test_data, test_label, weights)\n",
    "test_summary = 'loss:{:.3}, accuracy:{}'.format(loss, accuracy)\n",
    "print(test_summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
